#include <iostream>
#include "torch/script.h"
#include "torch/torch.h"
#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"
#include <vector>
#include <chrono>
#include <string>
#include <map>
using namespace cv;
using namespace std;

// Number of CAN IDs encoded into one pattern image (passed as argv[1..once_number]).
// Each ID contributes 3 rows, so the image is (once_number * 3) x 16 pixels.
constexpr int once_number = 64;

int main(int argc,char** argv)
{
  cout<<argv[once_number+1]<<endl;
  //分类(4种)
  map<int,string> mp;
  mp[1]= "Attack_free_dataset";
  mp[2]= "DoS_attack_dataset";
  mp[3]= "Fuzzy_attack_dataset";
  mp[4]= "Impersonation_attack_dataset";
  Mat PIL = Mat(once_number * 3, 16, CV_8UC1);
  int i, j;

  for (i = 0; i < once_number * 3; i++)
  {
    uchar* current = PIL.ptr<uchar>(i);
    for (j = 0; j < 16; j++)
      current[j] = 0;
  }
  //填色
  int k;
  for (i = 0; i < once_number; i++)
    for (j = 1; j < 4; j++)
    {
      k = argv[i+1][j]-'0';
      if (k == 0)
        PIL.ptr<uchar>(i * 3 + j - 1)[8] = 255;
      if (k > 0 && k <= 9)
        PIL.ptr<uchar>(i * 3 + j - 1)[k - 1] = 255;
      if (k > 10)
        PIL.ptr<uchar>(i * 3 + j - 1)[k - 40] = 255;
    }
  //生成图片(以线程序号为名字)
  char pic_name[100]="./";
  strcat(pic_name, argv[once_number+1]);
  strcat(pic_name, ".jpg");
  imwrite(pic_name, PIL);
  waitKey(0);

  //加载图片
  int img_size = 224;
  Mat image = imread(pic_name,0);
  if (image.empty())
    fprintf(stderr, "Can not load image\n");
  Size dsize = Size(224, 224);
  Mat dst;
  resize(image, dst, dsize, 0, 0);

  torch::Tensor dst_tensor=torch::tensor(at::ArrayRef<uint8_t>(dst.data, 1*dst.rows * dst.cols * 1)).view({1,dst.rows, dst.cols, 1});
  dst_tensor = dst_tensor.permute({0,3,1,2}).toType(torch::kFloat32).div(255);
  //cout<<dst_tensor<<endl;
  //cout<<"cudu support:"<< (torch::cuda::is_available()?"ture":"false")<<endl;

  //加载模型
  vector<torch::jit::IValue> inputs;
  inputs.push_back(dst_tensor);
  torch::jit::script::Module module = torch::jit::load("./AlexNet_IDS_new_del_0.979_e10_cpu.pt");
  module.eval();

  torch::Tensor outputs = module.forward(inputs).toTensor();
  inputs.pop_back();

  torch::Tensor pred = outputs.argmax(1);
  //cout<<outputs<<endl;
  cout<<argv[once_number+1]<<": "<<mp[pred.item().toInt()]<<endl;
  remove(pic_name);
  return 0;
}
